#include "read_wfc_nao.h"

#include "source_base/parallel_common.h"
#include "source_base/timer.h"
#include "source_io/write_wfc_nao.h"

#include "write_wfc_nao.h"
#include "source_base/module_external/scalapack_connector.h"
#include "source_io/filename.h"
#include "source_base/tool_title.h" // use title
#include "source_base/global_function.h" // use READ_VALUE

// mohan add 2025-10-19
/// @brief Read one single-precision real value from a text stream.
/// @param ifs  input file stream positioned at the value to read
/// @param data receives the value parsed by formatted extraction
/// @note The stream state is not inspected here; callers that need to detect
///       a failed read must check `ifs` themselves.
void ModuleIO::read_wfc_nao_one_data(std::ifstream& ifs, float& data)
{
    ifs >> data;
}

/// @brief Read one double-precision real value from a text stream.
/// @param ifs  input file stream positioned at the value to read
/// @param data receives the value parsed by formatted extraction
/// @note The stream state is not inspected here; callers that need to detect
///       a failed read must check `ifs` themselves.
void ModuleIO::read_wfc_nao_one_data(std::ifstream& ifs, double& data)
{
    ifs >> data;
}

/// @brief Read one double-precision complex value stored as two consecutive
///        numbers (real part first, imaginary part second).
/// @param ifs  input file stream positioned at the real part
/// @param data receives the complex value built from the two numbers read
/// @note The stream state is not inspected here; both parts start from zero,
///       so a part that cannot be extracted ends up as zero.
void ModuleIO::read_wfc_nao_one_data(std::ifstream& ifs, std::complex<double>& data)
{
    double re = 0.0;
    double im = 0.0;
    ifs >> re;
    ifs >> im;
    data.real(re);
    data.imag(im);
}

/// @brief Read one single-precision complex value stored as two consecutive
///        numbers (real part first, imaginary part second).
/// @param ifs  input file stream positioned at the real part
/// @param data receives the complex value built from the two numbers read
/// @note The stream state is not inspected here; both parts start from zero,
///       so a part that cannot be extracted ends up as zero.
void ModuleIO::read_wfc_nao_one_data(std::ifstream& ifs, std::complex<float>& data)
{
    float re = 0.0F;
    float im = 0.0F;
    ifs >> re;
    ifs >> im;
    data.real(re);
    data.imag(im);
}

/**
 * @brief Read NAO wave-function coefficients, band energies and occupations
 *        from text files into psid, ekb and wg.
 *
 * One file is read for every k-point index ik in [0, ekb.nr). The file name is
 * produced by ModuleIO::filename_output with prefix "wf"/"nao"; when
 * istep >= 0 the sub-directory "WFC/" is appended to global_readin_dir.
 * Only the process whose rank in ParaV.comm() is 0 opens and parses the files.
 * In MPI builds the full (nlocal x nbands) matrix is then redistributed into
 * psid with Cpxgemr2d and the ik-th rows of ekb and wg are broadcast; in
 * serial builds the buffer is copied into psid with BlasConnector::copy.
 *
 * @param global_readin_dir directory the file names are built from
 * @param ParaV             supplies global/local sizes, the communicator,
 *                          the BLACS context and the wave-function descriptor
 * @param psid              resized to (ekb.nr, local bands, local basis) and filled
 * @param ekb               row ik receives the band energies read from file ik
 * @param wg                row ik receives the occupations read from file ik
 * @param ik2iktot          forwarded to ModuleIO::filename_output
 * @param nkstot            forwarded to ModuleIO::filename_output
 * @param nspin             forwarded to ModuleIO::filename_output
 * @param skip_band         number of leading bands in each file that are read
 *                          and discarded before the bands that are kept
 * @param istep             forwarded to ModuleIO::filename_output; a value
 *                          >= 0 also selects the "WFC/" sub-directory
 * @return true if every file was read successfully; false otherwise, after
 *         printing the collected error message to std::cout.
 *
 * NOTE(review): gamma_only is true only for T == double, so the float
 * instantiation follows the multi-k file layout -- confirm this is intended.
 */
template <typename T>
bool ModuleIO::read_wfc_nao(
    const std::string& global_readin_dir,
    const Parallel_Orbitals& ParaV,
    psi::Psi<T>& psid,
    ModuleBase::matrix& ekb,
    ModuleBase::matrix& wg,
    const std::vector<int>& ik2iktot,
    const int nkstot,
    const int nspin,
    const int skip_band,
    const int istep)
{
    ModuleBase::TITLE("ModuleIO", "read_wfc_nao");
    // NOTE(review): assumes timer::tick toggles between start and stop, so it
    // is called again on every return path below -- confirm against the timer.
    ModuleBase::timer::tick("ModuleIO", "read_wfc_nao");

    const int nk = ekb.nr;

    const bool gamma_only = std::is_same<T, double>::value;
    const int out_type = 1; // only support .txt file now
    bool read_success = true;
    int myrank = 0;
    int nbands = ParaV.get_wfc_global_nbands(); // the global number of bands
    int nlocal = ParaV.get_wfc_global_nbasis(); // the global number of basis functions
    int nbands_local = ParaV.ncol_bands;        // the number of bands in the local process
    int nlocal_local = ParaV.nrow_bands;        // the number of basis functions in the local process

    if (gamma_only)
    {
        // In other places psi is initialised with ParaV.ncol for the gamma_only case;
        // it seems that the diagonalization of the double case needs a full matrix of psi.
        nbands_local = ParaV.ncol;
    }
    psid.resize(nk, nbands_local, nlocal_local);

#ifdef __MPI
    MPI_Comm_rank(ParaV.comm(), &myrank);
#endif

    // Parse one wave-function file into ctot (band-major, nlocal values per band)
    // and into row ik of ekb/wg. Returns false and fills error_message on failure.
    // The std::ifstream is closed by its destructor on every return path.
    auto read_one_file = [&](const std::string& ss,
                             std::stringstream& error_message,
                             const int ik,
                             std::vector<T>& ctot)
    {
        std::ifstream ifs;
        ifs.open(ss.c_str());
        if (!ifs)
        {
            error_message << " Can't open file:" << ss << std::endl;
            return false;
        }
        std::cout << " Read NAO wave functions from " << ss << std::endl;

        if (!gamma_only)
        {
            // multi-k files start with the 1-based k index followed by the k vector
            int ik_file = 0;
            double kx = 0.0;
            double ky = 0.0;
            double kz = 0.0;
            ModuleBase::GlobalFunc::READ_VALUE(ifs, ik_file);
            ifs >> kx >> ky >> kz; // consumed only to advance the stream
            if (ik_file != ik + 1)
            {
                error_message << "The k index read in from file do not match the k index generated by ABACUS!\n";
                error_message << " read in k index=" << ik_file;
                error_message << " ABACUS k index =" << ik+1 << std::endl;
                return false;
            }
        }

        int nbands_file = 0;
        int nlocal_file = 0;
        ModuleBase::GlobalFunc::READ_VALUE(ifs, nbands_file);
        ModuleBase::GlobalFunc::READ_VALUE(ifs, nlocal_file);

        // The loop below consumes skip_band + nbands bands, so the file has to
        // contain at least that many (a negative skip_band never relaxes the check).
        const int nbands_needed = nbands + std::max(skip_band, 0);
        if (nbands_needed > nbands_file)
        {
            error_message << "The number of bands to be read exceeds the number of bands in the file generated by ABACUS!\n";
            error_message << " nbands in the existing file=" << nbands_file;
            error_message << " nbands to be read into ABACUS=" << nbands << std::endl;
            if (skip_band > 0)
            {
                error_message << " number of leading bands to skip=" << skip_band << std::endl;
            }
            return false;
        }
        if (nlocal != nlocal_file)
        {
            error_message << "The nlocal read in from file do not match the nlocal generated by ABACUS!\n";
            error_message << " read in nlocal=" << nlocal_file;
            error_message << " ABACUS nlocal =" << nlocal << std::endl;
            return false;
        }

        for (int i = 0; i < skip_band + nbands; i++)
        {
            // the first skip_band useless bands are read into the 0th band to be overwritten
            const int ib_read = std::max(i - skip_band, 0);
            int ib = 0;
            ModuleBase::GlobalFunc::READ_VALUE(ifs, ib);
            ModuleBase::GlobalFunc::READ_VALUE(ifs, ekb(ik, ib_read));
            ModuleBase::GlobalFunc::READ_VALUE(ifs, wg(ik, ib_read));
            if (i+1 != ib)
            {
                error_message << "The band index read in from file do not match the global parameter band index!\n";
                error_message << " read in band index=" << ib;
                error_message << " band index=" << i+1 << std::endl;
                return false;
            }

            // widen before multiplying so that large nbands*nlocal cannot overflow int
            const std::size_t offset = static_cast<std::size_t>(ib_read) * static_cast<std::size_t>(nlocal);
            for (int j = 0; j < nlocal; j++)
            {
                read_wfc_nao_one_data(ifs, ctot[offset + j]);
            }

            // A failed extraction leaves zeros behind; without this check a file
            // truncated or corrupted inside the last band would be accepted silently.
            if (!ifs)
            {
                error_message << "Failed to read the wave function coefficients from file!\n";
                error_message << " file=" << ss;
                error_message << " band index=" << i+1 << std::endl;
                return false;
            }
        }
        return true;
    }; // end read one file

    std::string errors;

    // only the reading rank needs the full (nbands x nlocal) buffer
    std::vector<T> ctot;
    if (myrank == 0)
    {
        ctot.resize(static_cast<std::size_t>(nbands) * static_cast<std::size_t>(nlocal));
    }

    // the input directory does not depend on the k-point
    std::string readin_dir = global_readin_dir;
    if (istep >= 0)
    {
        readin_dir = readin_dir + "WFC/";
    }

    for (int ik = 0; ik < nk; ik++)
    {
        if (myrank == 0)
        {
            const bool out_app_flag = false;
            std::stringstream error_message;
            std::string ss = ModuleIO::filename_output(readin_dir,"wf","nao",
                    ik,ik2iktot,nspin,nkstot,out_type,out_app_flag,gamma_only,istep);

            read_success = read_one_file(ss, error_message, ik, ctot);
            errors = error_message.str();
        }
#ifdef __MPI
        // share the outcome so that every process takes the same branch below
        Parallel_Common::bcast_bool(read_success);
        Parallel_Common::bcast_string(errors);
#endif
        if (!read_success)
        {
            std::cout << " Error in reading wave function files!\n";
            std::cout << errors << std::endl;
            ModuleBase::timer::tick("ModuleIO", "read_wfc_nao");
            return false;
        }

        psid.fix_k(ik);
#ifdef __MPI
        // redistribute the full matrix described by pv_glb.desc into the layout of ParaV.desc_wfc
        Parallel_2D pv_glb;
        pv_glb.set(nlocal, nbands, std::max(nlocal, nbands), ParaV.blacs_ctxt);
        Cpxgemr2d(nlocal,
                  nbands,
                  ctot.data(),
                  1,
                  1,
                  pv_glb.desc,
                  psid.get_pointer(),
                  1,
                  1,
                  const_cast<int*>(ParaV.desc_wfc),
                  pv_glb.blacs_ctxt);
        Parallel_Common::bcast_double(&ekb(ik, 0), nbands);
        Parallel_Common::bcast_double(&wg(ik, 0), nbands);
#else
        BlasConnector::copy(nbands*nlocal, ctot.data(), 1, psid.get_pointer(), 1);
#endif
    } // end of loop over k-points

    ModuleBase::timer::tick("ModuleIO", "read_wfc_nao");
    return true;
}

// Explicit instantiation for real double-precision coefficients.
// Parameter names follow the order of the definition: skip_band, then istep.
template bool ModuleIO::read_wfc_nao<double>(const std::string& global_readin_dir,
    const Parallel_Orbitals& ParaV,
    psi::Psi<double>& psid,
    ModuleBase::matrix& ekb,
    ModuleBase::matrix& wg,
    const std::vector<int>& ik2iktot,
    const int nkstot,
    const int nspin,
    const int skip_band,
    const int istep);

// mohan add 2025-10-19
// Explicit instantiation for real single-precision coefficients.
// Parameter names follow the order of the definition: skip_band, then istep.
template bool ModuleIO::read_wfc_nao<float>(const std::string& global_readin_dir,
    const Parallel_Orbitals& ParaV,
    psi::Psi<float>& psid,
    ModuleBase::matrix& ekb,
    ModuleBase::matrix& wg,
    const std::vector<int>& ik2iktot,
    const int nkstot,
    const int nspin,
    const int skip_band,
    const int istep);

// Explicit instantiation for complex double-precision coefficients.
// Parameter names follow the order of the definition: skip_band, then istep.
template bool ModuleIO::read_wfc_nao<std::complex<double>>(const std::string& global_readin_dir,
    const Parallel_Orbitals& ParaV,
    psi::Psi<std::complex<double>>& psid,
    ModuleBase::matrix& ekb,
    ModuleBase::matrix& wg,
    const std::vector<int>& ik2iktot,
    const int nkstot,
    const int nspin,
    const int skip_band,
    const int istep);

// mohan add 2025-10-19
// Explicit instantiation for complex single-precision coefficients.
// Parameter names follow the order of the definition: skip_band, then istep.
template bool ModuleIO::read_wfc_nao<std::complex<float>>(const std::string& global_readin_dir,
    const Parallel_Orbitals& ParaV,
    psi::Psi<std::complex<float>>& psid,
    ModuleBase::matrix& ekb,
    ModuleBase::matrix& wg,
    const std::vector<int>& ik2iktot,
    const int nkstot,
    const int nspin,
    const int skip_band,
    const int istep);
